"""
Phase 6: Global vs Domestic Demographic Decomposition
Separates Z into GDP-weighted global component and country deviation to test
whether domestic aging or global demographic trends drive rates and inflation.

Outputs:
  - Table 10: Standardized domestic Z deviations on rates/inflation, plus a
    Z_1 x KAOPEN interaction panel (phase6_table10_global_domestic.md)
  - Table 11: Time-series world average regressions (phase6_table11_world_ts.md)
"""

import sys
from pathlib import Path
import numpy as np
import pandas as pd

# ── paths ────────────────────────────────────────────────────────────────────
PROJECT = Path("/mnt/c/demographics_capital_flows")
sys.path.insert(0, str(PROJECT / "multilateral" / "src"))
from model import PanelGLS

# Input panel (CSV) loaded by main(), and the directory the markdown tables
# are written to.
DATA_PATH = PROJECT / "monetary" / "data" / "processed" / "monetary_panel.csv"
TABLE_DIR = PROJECT / "monetary" / "output" / "tables"
# Created at import time so the table writes in main() find the directory.
TABLE_DIR.mkdir(parents=True, exist_ok=True)

# ISO3 codes (38 entries) that main() uses to select the OECD subsample of
# the panel via the "iso3" column.
OECD_38 = [
    "AUS", "AUT", "BEL", "CAN", "CHL", "COL", "CRI", "CZE", "DNK", "EST",
    "FIN", "FRA", "DEU", "GRC", "HUN", "ISL", "IRL", "ISR", "ITA", "JPN",
    "KOR", "LVA", "LTU", "LUX", "MEX", "NLD", "NZL", "NOR", "POL", "PRT",
    "SVK", "SVN", "ESP", "SWE", "CHE", "TUR", "GBR", "USA",
]


# ── helpers ──────────────────────────────────────────────────────────────────
def run_model(panel, dep_var, rhs_vars):
    """Fit PanelGLS of ``dep_var`` on ``rhs_vars`` using complete cases only.

    Rows with a missing value in the dependent variable, any regressor,
    ``iso3`` or ``year`` are dropped first. Returns ``None`` (after printing
    a message) when fewer than 50 rows remain or when the fit raises;
    otherwise returns a dict with per-regressor ``coefs``/``se``/``pvals``
    plus ``r2``, ``nobs`` and ``ncountries`` read from the fitted model.
    """
    needed = [dep_var] + rhs_vars + ["iso3", "year"]
    sample = panel[needed].dropna()
    n_rows = len(sample)
    if n_rows < 50:
        print(f"  SKIP {dep_var}: only {n_rows} obs after dropna")
        return None

    # Everything that touches the estimator stays inside the try so a failed
    # fit degrades to an empty table column instead of aborting the run.
    try:
        response = sample[dep_var].values
        design = sample[rhs_vars].values
        model = PanelGLS()
        model.fit(response, design, sample["iso3"].values, sample["year"].values)
        summary = {}
        for key, attr in (("coefs", model.beta), ("se", model.se), ("pvals", model.pvalues)):
            summary[key] = dict(zip(rhs_vars, attr))
        summary["r2"] = model.r_squared
        summary["nobs"] = model.n_obs
        summary["ncountries"] = model.n_countries
        return summary
    except Exception as e:
        print(f"  ERROR {dep_var}: {e}")
        return None


def star(p):
    """Return the significance marker for p-value ``p``.

    ``"***"`` below 0.01, ``"**"`` below 0.05, ``"*"`` below 0.10, and an
    empty string otherwise (including NaN, which fails every comparison).
    """
    thresholds = ((0.01, "***"), (0.05, "**"), (0.10, "*"))
    for cutoff, marker in thresholds:
        if p < cutoff:
            return marker
    return ""


def fmt_coef(res, var):
    """Render one table cell as ``'coef*** (se)'`` with three decimals.

    Returns an empty string when ``res`` is ``None`` or when ``var`` is not
    among the regressors in ``res["coefs"]``.
    """
    if res is None:
        return ""
    if var not in res["coefs"]:
        return ""
    estimate = res["coefs"][var]
    std_err = res["se"][var]
    p_value = res["pvals"][var]
    return f"{estimate:.3f}{star(p_value)} ({std_err:.3f})"


def build_markdown_table(title, row_vars, col_labels, results, footer_rows):
    """Assemble a pipe-delimited markdown regression table.

    The output is a ``# title`` heading, a header row, one row per entry of
    ``row_vars`` (each cell produced by ``fmt_coef`` for the matching entry
    of ``results``), a separator row, and then one row per
    ``(label, values)`` pair in ``footer_rows``.
    """
    def as_row(first_cell, cells):
        return f"| {first_cell} | " + " | ".join(cells) + " |"

    rule = "|" + "---|" * (len(col_labels) + 1)
    out = [f"# {title}\n", as_row("Variable", col_labels), rule]
    out.extend(
        as_row(var, [fmt_coef(r, var) for r in results]) for var in row_vars
    )
    # The same separator is repeated to set the footer statistics apart.
    out.append(rule)
    out.extend(as_row(label, vals) for label, vals in footer_rows)
    return "\n".join(out)


def run_ols_hac(y, X, max_lag=None):
    """OLS with Newey-West (Bartlett-kernel) HAC standard errors.

    Parameters
    ----------
    y : array-like, shape (n,)
        Dependent variable, ordered in time.
    X : array-like, shape (n, k)
        Regressor matrix. No constant is added here; include a column of
        ones if an intercept is wanted.
    max_lag : int, optional
        Highest autocovariance lag included. Defaults to
        ``floor(4 * (n / 100) ** (2 / 9))``.

    Returns
    -------
    dict
        ``coefs``, ``se``, ``pvals`` (arrays of length k), ``r2``, ``nobs``
        and ``resid``. ``r2`` is NaN when ``y`` has no variation.

    Raises
    ------
    ValueError
        If there are not more observations than regressors.
    """
    from scipy import stats as sp_stats

    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float).ravel()

    n, k = X.shape
    if n <= k:
        # With n <= k the residual degrees of freedom are non-positive and
        # every standard error / p-value would be meaningless.
        raise ValueError(
            f"run_ols_hac needs more observations than regressors (n={n}, k={k})"
        )
    if max_lag is None:
        max_lag = int(np.floor(4 * (n / 100) ** (2 / 9)))

    # OLS estimates
    XtX_inv = np.linalg.inv(X.T @ X)
    beta = XtX_inv @ (X.T @ y)
    resid = y - X @ beta

    # Newey-West HAC covariance. The per-observation scores x_t * e_t do not
    # depend on the lag, so they are computed once outside the loop.
    scores = X * resid[:, None]
    S = scores.T @ scores / n
    for j in range(1, max_lag + 1):
        w = 1 - j / (max_lag + 1)  # Bartlett kernel weight
        Gamma_j = (scores[j:].T @ scores[:-j]) / n
        S += w * (Gamma_j + Gamma_j.T)

    V_hac = n * XtX_inv @ S @ XtX_inv
    se = np.sqrt(np.diag(V_hac))

    # Two-sided p-values from the survival function: 1 - cdf rounds to
    # exactly zero once |t| is large, whereas sf keeps the tail probability.
    t_stats = beta / se
    pvals = 2 * sp_stats.t.sf(np.abs(t_stats), df=n - k)

    # R-squared (undefined when y is constant)
    ss_res = np.sum(resid ** 2)
    ss_tot = np.sum((y - y.mean()) ** 2)
    r2 = 1 - ss_res / ss_tot if ss_tot > 0 else np.nan

    return {
        "coefs": beta,
        "se": se,
        "pvals": pvals,
        "r2": r2,
        "nobs": n,
        "resid": resid,
    }


# ── main ─────────────────────────────────────────────────────────────────────
def _full_vs_oecd_columns(panel, oecd, dep_var, rhs_vars):
    """Estimate one specification on the full and OECD samples.

    Returns a list of ``(column_label, result)`` pairs. When both fits
    succeed with the same number of observations, a single column labelled
    ``"<dep_var> (OECD*)"`` carrying the full-sample result is returned;
    otherwise separate ``(Full)`` and ``(OECD)`` columns are returned, with
    ``None`` for any fit that ``run_model`` skipped or failed.
    """
    res_full = run_model(panel, dep_var, rhs_vars)
    res_oecd = run_model(oecd, dep_var, rhs_vars)
    n_full = res_full["nobs"] if res_full else 0
    n_oecd = res_oecd["nobs"] if res_oecd else 0
    if n_full == n_oecd and n_full > 0:
        return [(f"{dep_var} (OECD*)", res_full)]
    return [(f"{dep_var} (Full)", res_full), (f"{dep_var} (OECD)", res_oecd)]


def _estimate_panel_table(panel, oecd, specs, report_var):
    """Run every ``(dep_var, rhs_vars)`` spec; return ``(results, labels)``.

    For specs that collapse to a single column, the coefficient and p-value
    of ``report_var`` are printed as a progress line.
    """
    results, labels = [], []
    for dep_var, rhs_vars in specs:
        columns = _full_vs_oecd_columns(panel, oecd, dep_var, rhs_vars)
        if len(columns) == 1:
            label, res = columns[0]
            if report_var in res["coefs"]:
                print(f"  {label}: {report_var}={res['coefs'][report_var]:.3f} "
                      f"(p={res['pvals'][report_var]:.3f})")
        for label, res in columns:
            labels.append(label)
            results.append(res)
    return results, labels


def _panel_footer(results):
    """Footer rows (R², N obs, N countries) for a panel-regression table."""
    return [
        ("R²", [f"{r['r2']:.3f}" if r else "" for r in results]),
        ("N obs", [f"{r['nobs']:,}" if r else "" for r in results]),
        ("N countries", [f"{r['ncountries']}" if r else "" for r in results]),
    ]


def _gdp_weighted_world(panel, variables):
    """Collapse the panel to one row per year of ``ngdp_usd``-weighted means.

    A year/variable cell is NaN unless more than 10 countries have both the
    variable and ``ngdp_usd`` non-missing and the weights sum to a positive
    number.
    """
    rows = []
    for yr, grp in panel.groupby("year"):
        row = {"year": yr}
        for var in variables:
            sub = grp[["ngdp_usd", var]].dropna()
            if len(sub) > 10 and sub["ngdp_usd"].sum() > 0:
                row[var] = np.average(sub[var], weights=sub["ngdp_usd"])
            else:
                row[var] = np.nan
        rows.append(row)
    # Sorting by year matters: the HAC lags assume time-ordered rows.
    return pd.DataFrame(rows).sort_values("year").reset_index(drop=True)


def _world_ts_column(world, dep_var, rhs_vars, min_years=10):
    """OLS/HAC regression of ``dep_var`` on a constant plus ``rhs_vars``.

    Returns a result dict keyed like the panel results (``coefs``, ``se``,
    ``pvals`` per regressor, ``r2``, ``nobs``), with the intercept omitted,
    or ``None`` when fewer than ``min_years`` complete years are available.
    """
    sub = world[["year", dep_var] + rhs_vars].dropna()
    if len(sub) < min_years:
        print(f"  {dep_var}: insufficient obs ({len(sub)})")
        return None
    y = sub[dep_var].values
    X = np.column_stack([np.ones(len(y))] + [sub[v].values for v in rhs_vars])
    ols = run_ols_hac(y, X)
    # Element 0 of each array belongs to the intercept, which is not tabulated.
    res = {
        "coefs": dict(zip(rhs_vars, ols["coefs"][1:])),
        "se": dict(zip(rhs_vars, ols["se"][1:])),
        "pvals": dict(zip(rhs_vars, ols["pvals"][1:])),
        "r2": ols["r2"],
        "nobs": ols["nobs"],
    }
    lead = rhs_vars[0]
    print(f"  {dep_var}: {lead}={res['coefs'][lead]:.3f} "
          f"(p={res['pvals'][lead]:.3f}), R²={ols['r2']:.3f}, N={ols['nobs']}")
    return res


def main():
    """Run Phase 6 end to end and write Tables 10 and 11 as markdown.

    Reads the panel from ``DATA_PATH``, estimates the standardized
    domestic-Z panel regressions (Table 10 and its KAOPEN interaction panel)
    on the full and OECD samples, estimates GDP-weighted world time-series
    regressions (Table 11), writes both tables under ``TABLE_DIR`` and
    prints progress plus a short summary.
    """
    print("=" * 70)
    print("PHASE 6: GLOBAL VS DOMESTIC DEMOGRAPHIC DECOMPOSITION")
    print("=" * 70)

    panel = pd.read_csv(DATA_PATH)
    print(f"Panel: {len(panel):,} obs, {panel['iso3'].nunique()} countries")

    oecd = panel[panel["iso3"].isin(OECD_38)].copy()
    print(f"OECD:  {len(oecd):,} obs, {oecd['iso3'].nunique()} countries")

    # Report coverage of the global/domestic Z columns (a missing column
    # raises KeyError here, before any estimation starts).
    for p in [1, 2, 3]:
        for col in [f"global_Z_{p}", f"domestic_Z_{p}_dev"]:
            n = panel[col].notna().sum()
            print(f"  {col}: {n:,} non-missing")

    controls_rate = ["rgdp_growth", "inflation", "fiscal_bal_gdp", "kaopen", "nfa_gdp_lag"]
    controls_infl = ["rgdp_growth", "output_gap", "fiscal_bal_gdp", "kaopen", "nfa_gdp_lag"]

    # ── Standardize the domestic Z deviations ───────────────────────────────
    # NOTE: global_Z_p (GDP-weighted mean of Z polynomials) is numerically ~0
    # because Z is constructed from demeaned age shares. The "domestic deviation"
    # is therefore essentially Z itself. Only domestic_Z_dev is standardized, so
    # its coefficient reads as "per 1-SD change". Both samples are scaled by
    # the full-panel SD so the units match across columns.
    for p in [1, 2, 3]:
        dvar = f"domestic_Z_{p}_dev"
        std_col = f"{dvar}_std"
        sd_d = panel[dvar].std()
        if sd_d > 0:
            panel[std_col] = panel[dvar] / sd_d
            oecd[std_col] = oecd[dvar] / sd_d
            print(f"  {dvar}: SD={sd_d:.4f}, standardized")
        else:
            # Still create the column (all NaN) so the regressions that use
            # it are skipped by run_model instead of failing on a missing
            # column name.
            panel[std_col] = np.nan
            oecd[std_col] = np.nan
            print(f"  WARNING {dvar}: SD={sd_d} is not positive; "
                  f"{std_col} set to NaN and dependent regressions will be skipped")

    domestic_z_std = ["domestic_Z_1_dev_std",
                      "domestic_Z_2_dev_std",
                      "domestic_Z_3_dev_std"]

    # ── TABLE 10: Domestic Demographics (Standardized) on Rates and Inflation ──
    # NOTE: global_Z is numerically ~0 because Z polynomials are constructed from
    # demeaned age shares, so their GDP-weighted cross-country mean is zero by
    # construction. Table 10 therefore reports standardized domestic deviations
    # (which are essentially just standardized Z). Year fixed effects in PanelGLS
    # absorb the common global demographic trend.
    print("\n── Table 10: Domestic Z (standardized) on Rates and Inflation ──")

    dvs_rate = ["real_bond_10y", "real_short_3m"]
    dvs_infl = ["inflation"]

    specs_t10 = ([(dv, domestic_z_std + controls_rate) for dv in dvs_rate]
                 + [(dv, domestic_z_std + controls_infl) for dv in dvs_infl])
    results_t10, col_labels_t10 = _estimate_panel_table(
        panel, oecd, specs_t10, "domestic_Z_1_dev_std",
    )

    # dict.fromkeys de-duplicates the shared controls while keeping order.
    row_vars_t10 = list(dict.fromkeys(domestic_z_std + controls_rate + controls_infl))

    md_t10 = build_markdown_table(
        "Table 10: Domestic Demographics (Standardized) on Rates and Inflation",
        row_vars_t10, col_labels_t10, results_t10, _panel_footer(results_t10),
    )
    md_t10 += "\n\n*Coefficients represent per 1-SD change in domestic demographic deviation. Year fixed effects absorb the common global demographic trend.*"
    print("\n" + md_t10)
    t10_path = TABLE_DIR / "phase6_table10_global_domestic.md"
    # Explicit UTF-8: the table text contains non-ASCII characters (², ×).
    t10_path.write_text(md_t10, encoding="utf-8")
    print(f"\nSaved: {t10_path}")

    # (B) With domestic_Z_1_dev_std × kaopen interaction
    print("\n── Table 10b: With Domestic Z_1 × KAOPEN Interaction ──")
    interaction_var = "domestic_Z_1_dev_std_x_kaopen"
    if interaction_var not in panel.columns:
        panel[interaction_var] = panel["domestic_Z_1_dev_std"] * panel["kaopen"]
        oecd[interaction_var] = oecd["domestic_Z_1_dev_std"] * oecd["kaopen"]

    domestic_z_std_int = domestic_z_std + [interaction_var]
    specs_t10b = ([(dv, domestic_z_std_int + controls_rate) for dv in dvs_rate]
                  + [(dv, domestic_z_std_int + controls_infl) for dv in dvs_infl])
    results_t10b, col_labels_t10b = _estimate_panel_table(
        panel, oecd, specs_t10b, interaction_var,
    )

    row_vars_t10b = list(dict.fromkeys(domestic_z_std_int + controls_rate + controls_infl))

    # Append interaction panel to the main table
    md_t10 += "\n\n"
    md_t10 += build_markdown_table(
        "Table 10b: With Domestic Z_1 × KAOPEN Interaction",
        row_vars_t10b, col_labels_t10b, results_t10b, _panel_footer(results_t10b),
    )
    t10_path.write_text(md_t10, encoding="utf-8")
    print(f"\nUpdated: {t10_path}")

    # ── TABLE 11: Time-Series World Average ───────────────────────────────
    # NOTE: Z polynomials are constructed from demeaned age shares, so their
    # GDP-weighted global mean is numerically zero. For the world time-series
    # regression, we use GDP-weighted old_dep and working_age_share which
    # have meaningful time variation.
    print("\n── Table 11: Time-Series World Average ──")

    vars_to_collapse = [
        "real_bond_10y", "real_short_3m", "inflation",
        "old_dep", "working_age_share",
        "rgdp_growth",
    ]
    world = _gdp_weighted_world(panel, vars_to_collapse)
    print(f"World time series: {len(world)} years ({world['year'].min()}-{world['year'].max()})")
    print(f"  real_bond_10y non-missing: {world['real_bond_10y'].notna().sum()}")
    print(f"  inflation non-missing:     {world['inflation'].notna().sum()}")

    # Standardize demographic variables for interpretable coefficients
    for var in ["old_dep", "working_age_share"]:
        sd = world[var].std()
        if sd > 0:
            world[f"{var}_std"] = world[var] / sd
            print(f"  {var}: SD={sd:.4f}, standardized")
        else:
            # Keep the column present so the regressions below report
            # "insufficient obs" rather than raising KeyError.
            world[f"{var}_std"] = np.nan
            print(f"  WARNING {var}: SD={sd} is not positive; {var}_std set to NaN")

    # OLS regressions with Newey-West HAC SEs
    z_world = ["old_dep_std", "working_age_share_std"]
    rhs_rate_ts = z_world + ["rgdp_growth", "inflation"]
    rhs_infl_ts = z_world + ["rgdp_growth"]
    specs_t11 = [
        ("real_bond_10y", rhs_rate_ts, "10y rate (World)"),
        ("real_short_3m", rhs_rate_ts, "3m rate (World)"),
        ("inflation", rhs_infl_ts, "Inflation (World)"),
    ]

    results_t11 = []
    col_labels_t11 = []
    for dv, rhs, label in specs_t11:
        results_t11.append(_world_ts_column(world, dv, rhs))
        col_labels_t11.append(label)
    rhs_labels_t11 = list(dict.fromkeys(rhs_rate_ts + rhs_infl_ts))

    # Build Table 11 markdown
    footer_t11 = [
        ("R²", [f"{r['r2']:.3f}" if r else "" for r in results_t11]),
        ("N years", [f"{r['nobs']}" if r else "" for r in results_t11]),
        ("SE", ["Newey-West HAC"] * len(results_t11)),
    ]

    md_t11 = build_markdown_table(
        "Table 11: World Average Time-Series Regressions (OLS, Newey-West HAC)",
        rhs_labels_t11, col_labels_t11, results_t11, footer_t11,
    )
    md_t11 += "\n\n*GDP-weighted world averages. Demographic variables (old_dep, working_age_share) standardized to per-1-SD units.*"
    print("\n" + md_t11)
    t11_path = TABLE_DIR / "phase6_table11_world_ts.md"
    t11_path.write_text(md_t11, encoding="utf-8")
    print(f"\nSaved: {t11_path}")

    # ── Summary ──────────────────────────────────────────────────────────
    print("\n── Key Findings ──")

    # Domestic Z_1 on 10y rates (the first Table 10 column is always the
    # first real_bond_10y column, collapsed or full-sample).
    res_10y = results_t10[0]
    if res_10y and "domestic_Z_1_dev_std" in res_10y["coefs"]:
        d1 = res_10y["coefs"]["domestic_Z_1_dev_std"]
        dp = res_10y["pvals"]["domestic_Z_1_dev_std"]
        print(f"  10y rate: domestic_Z_1_dev_std={d1:.3f} (p={dp:.3f})")
        print("    -> Per 1-SD domestic demographic deviation")

    # Inflation: first available inflation column of Table 10
    for res, lbl in zip(results_t10, col_labels_t10):
        if res and lbl.startswith("inflation") and "domestic_Z_1_dev_std" in res["coefs"]:
            d1 = res["coefs"]["domestic_Z_1_dev_std"]
            dp = res["pvals"]["domestic_Z_1_dev_std"]
            print(f"  Inflation: domestic_Z_1_dev_std={d1:.3f} (p={dp:.3f})")
            break

    print("\nPhase 6 complete.")


if __name__ == "__main__":
    main()
